import torch
import torch.nn.functional as F 
from torch.utils.data import DataLoader
import h5py
import os
import numpy as np
import operator
from itertools import accumulate
import copy
from utils import NodeType
import math

# Module-level mutable state. Neither name is read or written by the
# functions visible in this file -- presumably leftovers, or consumed by code
# that imports this module. TODO confirm before removing.
previous_face_edges = []
frame_index = [0]
def custom_collate(data):
    """Merge a list of per-frame sample dicts into one batched dict of tensors.

    Every field is converted with ``torch.tensor`` and concatenated along
    dim 0. Some fields (presumably index arrays -- confirm against the data
    producer) are shifted before concatenation:

    * ``world_edge``, ``mesh_edge``, ``elements`` and ``faces`` are shifted by
      the total number of ``stress`` rows in all preceding samples;
    * ``faces_edges`` is shifted by the total number of ``elements`` rows in
      all preceding samples;
    * ``faces_to_faces`` and ``cells_faces`` are shifted by the total number
      of ``faces`` rows in all preceding samples.

    Args:
        data: sequence of dicts. Each value must be accepted by
            ``torch.tensor``; ``"stress"``, ``"elements"`` and ``"faces"``
            must additionally expose ``.shape`` (presumably NumPy arrays).

    Returns:
        dict with the same sixteen keys, each mapped to a single tensor.
    """

    def as_tensors(key):
        # One tensor per sample for the given field.
        return [torch.tensor(sample[key]) for sample in data]

    def leading_offsets(key):
        # Exclusive prefix sum of first-dimension sizes: entry i is the row
        # count contributed by samples 0..i-1.
        sizes = [sample[key].shape[0] for sample in data]
        return torch.tensor([0] + list(accumulate(sizes))[:-1])

    def merged(key, shifts=None):
        # Concatenate a field across samples, optionally adding the
        # per-sample shift (kept as a 0-dim tensor, as in plain indexing).
        parts = as_tensors(key)
        if shifts is not None:
            parts = [part + shifts[idx] for idx, part in enumerate(parts)]
        return torch.concat(parts, dim=0)

    node_shift = leading_offsets("stress")
    element_shift = leading_offsets("elements")
    face_shift = leading_offsets("faces")

    # NOTE(review): "faces_edges" is shifted by element offsets, not face or
    # node offsets -- confirm this matches what the field indexes.
    return {
        "final_flag": merged("final_flag"),
        "cells_faces": merged("cells_faces", face_shift),
        "last_stress": merged("last_stress"),
        "last_pos": merged("last_pos"),
        "world_pos": merged("world_pos"),
        "stress": merged("stress"),
        "node_type": merged("node_type"),
        "target": merged("target"),
        "next_pos_need": merged("next_pos_need"),
        "elements": merged("elements", node_shift),
        "faces": merged("faces", node_shift),
        "faces_edges": merged("faces_edges", element_shift),
        "faces_to_faces": merged("faces_to_faces", face_shift),
        "mesh_edge": merged("mesh_edge", node_shift),
        "mesh_pos": merged("mesh_pos"),
        "world_edge": merged("world_edge", node_shift),
    }
    
def GetDataSetDP(dataset_dir, split, batch_size = 1):
    """Load the sampled frames of a split into memory and wrap them in a DataLoader.

    Two HDF5 files under ``dataset_dir`` are opened read-only:

    * ``{split}.h5`` supplies the trace ids (its top-level keys) and, through
      the first dimension of each trace's ``"stress"`` dataset, the trace
      length.
    * ``data_{split}_trace99.h5`` supplies the per-frame groups, addressed as
      ``[str(trace_id)][str(frame_id)]``.

    For every trace, frames ``0, 4, 8, ...`` below ``trace_length - 4`` are
    copied into NumPy arrays. ``final_flag`` is True for the last frame taken
    from a trace.

    Args:
        dataset_dir: directory containing both HDF5 files.
        split: split name used to build both file names (e.g. ``"test"``).
        batch_size: number of frames per batch handed to ``custom_collate``.

    Returns:
        A non-shuffling ``DataLoader`` over the collected frame dicts that
        uses ``custom_collate`` as its collate function.

    Raises:
        OSError: if either HDF5 file cannot be opened (raised by h5py).
        KeyError: if a trace id or frame id is missing from the frame file.
    """
    # Sampling stride; also the margin kept free at the end of each trace.
    step_slice = 4

    file_path = os.path.join(dataset_dir, split + ".h5")
    input_file_path = os.path.join(dataset_dir, f"data_{split}_trace99.h5")
    frame_data_collected = []

    with h5py.File(input_file_path, 'r') as f, h5py.File(file_path, 'r') as file:
        for trace_id in file.keys():
            print(trace_id)
            group = f[str(trace_id)]
            trace_length = file[trace_id]["stress"].shape[0]
            last_start = trace_length - step_slice

            for frame_id in range(0, last_start, step_slice):
                data_set = group[str(frame_id)]
                # Read once; reused for both "stress" and "last_stress".
                stress = np.array(data_set["stress"])

                frame_data_collected.append(
                    {
                        "final_flag": np.array([frame_id + step_slice >= last_start]),
                        "mesh_pos": np.array(data_set["mesh_pos"]),
                        "mesh_edge": np.array(data_set["mesh_edge"]),
                        "cells_faces": np.array(data_set["cells_faces"]),
                        # NOTE(review): "last_stress" is filled from the
                        # frame's "stress" dataset, not a separate one --
                        # confirm this is intended. astype() returns a copy,
                        # so it does not alias "stress".
                        "last_stress": stress.astype(np.float32),
                        "last_pos": np.array(data_set["last_pos"]).astype(np.float32),
                        "world_pos": np.array(data_set["world_pos"]),
                        "stress": stress,
                        "node_type": np.array(data_set["node_type"]),
                        "elements": np.array(data_set["elements"]),
                        "faces": np.array(data_set["faces"]),
                        "world_edge": np.array(data_set["world_edge"]),
                        "faces_edges": np.array(data_set["faces_edges"]),
                        "next_pos_need": np.array(data_set["next_pos_need"]).astype(np.float32),
                        "faces_to_faces": np.array(data_set["faces_to_faces"]),
                        "target": np.array(data_set["target"]),
                    }
                )

    dataset = DataLoader(
        dataset=frame_data_collected,
        batch_size=batch_size,
        shuffle=False,
        collate_fn=custom_collate,
    )
    return dataset
if __name__ == '__main__':
    # Manual smoke run: build the DataLoader for the "test" split from a
    # hard-coded dataset directory. The result is not used further here.
    dataset = GetDataSetDP("/meshgraphnet_data/deepmind_h5/deforming_plate/", "test")
    
    
